import os
import json
import argparse
import pandas as pd
from collections import OrderedDict
from common.flatten_object import flatten_object


def collect_training_reports(root_path, report_name):
    """Recursively load every JSON report named ``report_name`` below ``root_path``.

    Args:
        root_path: directory whose tree is searched (the root itself included).
        report_name: exact file name of the report to look for in each directory.

    Returns:
        An ``OrderedDict`` mapping a per-directory key to the parsed JSON report,
        in sorted top-down traversal order. The key is the basename of the
        directory that holds the report. When several directories share the
        same basename, each of them is keyed by its path relative to
        ``root_path`` instead, so that no report silently overwrites another.

    Raises:
        Whatever ``open`` / ``json.load`` raise for an unreadable or malformed
        report file.
    """
    found = []  # (directory, parsed report) pairs in traversal order
    for subdir, dirs, files in os.walk(root_path):
        # os.walk lists entries in arbitrary filesystem order; sorting the
        # directory list in place makes the traversal (and therefore the
        # order of the returned mapping) reproducible.
        dirs.sort()
        if report_name in files:
            with open(os.path.join(subdir, report_name)) as fp:
                found.append((subdir, json.load(fp)))

    # normpath drops a trailing separator, which would otherwise turn the
    # basename of the root directory into an empty string.
    names = [os.path.basename(os.path.normpath(subdir)) for subdir, _ in found]
    occurrences = {}
    for name in names:
        occurrences[name] = occurrences.get(name, 0) + 1

    res = OrderedDict()
    for (subdir, report), name in zip(found, names):
        if occurrences[name] > 1:
            # Ambiguous basename: fall back to the (unique) relative path.
            name = os.path.relpath(subdir, root_path)
        res[name] = report
    return res


def generate_summary(reports, filename, detailed=False):
    """Write a spreadsheet with one row per training report.

    Args:
        reports: mapping of report key -> report dict; only the values are used.
        filename: destination handed to ``DataFrame.to_excel``.
        detailed: when False, the ``'params'`` entry of a report's ``'pruning'``
            section is left out of the summary.

    The report dicts passed in are not modified. Each report is run through
    ``flatten_object`` before becoming a row (presumably nested keys become
    dotted column names such as ``'pruning.fraction_zeros'`` - confirm against
    ``common.flatten_object``). Duplicate rows are dropped before writing.
    """
    rows = []
    for report in reports.values():
        # Work on a shallow copy so the caller's report is left untouched.
        report = dict(report)
        if 'regime' in report:
            report['regime'] = str(report['regime'])
        pruning = report.get('pruning')
        if not detailed and isinstance(pruning, dict) and 'params' in pruning:
            # Rebuild the section instead of deleting in place (no mutation).
            report['pruning'] = {k: v for k, v in pruning.items() if k != 'params'}
        flat = flatten_object(report)
        # Only keys()/values() are relied upon, exactly as the row
        # construction did before.
        rows.append(dict(zip(flat.keys(), flat.values())))

    # Build the frame in one go: DataFrame.append no longer exists in
    # pandas >= 2.0, and growing a frame row by row is quadratic anyway.
    summary = pd.DataFrame(rows)

    # Preferred columns first (when present), followed by every other column
    # in order of first appearance so the layout is deterministic.
    cols = summary.columns.tolist()
    order = ['model', 'dataset', 'pruning.model_n_params', 'pruning.description', 'pruning.pruner_type', 'epochs',
             'fine-tune', 'regime', 'top1', 'top5', 'pruning.fraction_zeros', 'pruning.fraction_non_zeros',
             'pruning.infer_speedup', 'pruning.train_speedup']
    preferred = [c for c in order if c in cols]
    remaining = [c for c in cols if c not in preferred]
    summary = summary[preferred + remaining]
    summary = summary.drop_duplicates()
    summary.to_excel(filename)


def get_args():
    """Parse the command-line options of the summary generator.

    Returns:
        The ``argparse.Namespace`` with ``results_dir`` and ``report_name``.
    """
    # (flag, default, help) for every supported option.
    options = (
        ('--results-dir', './results', 'base path for training results'),
        ('--report-name', 'report.json', 'collected files that are added to summary'),
    )
    cli = argparse.ArgumentParser(description='Generate trainings summary')
    for flag, default, help_text in options:
        cli.add_argument(flag, default=default, help=help_text)
    return cli.parse_args()


def main():
    """Collect the reports below the results directory and write the summary.

    Prints exactly one status line: an error when the directory does not
    exist or holds no reports, otherwise the location of the generated file.
    """
    cli_args = get_args()
    results_dir = cli_args.results_dir

    if os.path.exists(results_dir):
        found = collect_training_reports(results_dir, cli_args.report_name)
        if found:
            # NOTE(review): the '.xls' extension needs a legacy Excel writer
            # engine - confirm the installed pandas still provides one.
            summary_path = os.path.join(results_dir, 'training_summary.xls')
            generate_summary(found, summary_path)
            message = 'Report generated at {}'.format(summary_path)
        else:
            message = 'No reports found at subdirectories of {} - exiting'.format(results_dir)
    else:
        message = 'Invalid path={} - exiting'.format(results_dir)

    print(message)


# Run the summary generation only when executed as a script, not on import.
if __name__ == '__main__':
    main()

